{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction\n",
    "In this notebook we demonstrates how to use `SparseEmbedding` and `SparseAdam` to obtain stroger performance with sparse gradient."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare Environment\n",
    "Before you start with Apis delivered by bigdl-nano, you have to make sure BigDL-Nano is correctly installed for TensorFlow. If not, please follow [this](../../../../../docs/readthedocs/source/doc/Nano/Overview/nano.md) to set up your environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bigdl.nano.tf.keras import Model, Sequential"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the Data\n",
    "We demonstrate with imdb_reviews, a large dataset of movie reviews."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow_datasets as tfds\n",
    "(raw_train_ds, raw_val_ds, raw_test_ds), info = tfds.load(\n",
    "    \"imdb_reviews\",\n",
    "    split=['train[:80%]', 'train[80%:]', 'test'],\n",
    "    as_supervised=True,\n",
    "    batch_size=32,\n",
    "    shuffle_files=False,\n",
    "    with_info=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's preview a few samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-07-24 22:25:55.972229: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<style  type=\"text/css\" >\n",
       "</style><table id=\"T_96345b40_0b9f_11ed_912f_0242ac110002\" ><thead>    <tr>        <th class=\"blank level0\" ></th>        <th class=\"col_heading level0 col0\" >label</th>        <th class=\"col_heading level0 col1\" >text</th>    </tr></thead><tbody>\n",
       "                <tr>\n",
       "                        <th id=\"T_96345b40_0b9f_11ed_912f_0242ac110002level0_row0\" class=\"row_heading level0 row0\" >0</th>\n",
       "                        <td id=\"T_96345b40_0b9f_11ed_912f_0242ac110002row0_col0\" class=\"data row0 col0\" >0 (neg)</td>\n",
       "                        <td id=\"T_96345b40_0b9f_11ed_912f_0242ac110002row0_col1\" class=\"data row0 col1\" >This was an absolutely terrible movie. Don&#x27;t be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie&#x27;s ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor&#x27;s like Christopher Walken&#x27;s good name. I could barely sit through it.</td>\n",
       "            </tr>\n",
       "            <tr>\n",
       "                        <th id=\"T_96345b40_0b9f_11ed_912f_0242ac110002level0_row1\" class=\"row_heading level0 row1\" >1</th>\n",
       "                        <td id=\"T_96345b40_0b9f_11ed_912f_0242ac110002row1_col0\" class=\"data row1 col0\" >0 (neg)</td>\n",
       "                        <td id=\"T_96345b40_0b9f_11ed_912f_0242ac110002row1_col1\" class=\"data row1 col1\" >I have been known to fall asleep during films, but this is usually due to a combination of things including, really tired, being warm and comfortable on the sette and having just eaten a lot. However on this occasion I fell asleep because the film was rubbish. The plot development was constant. Constantly slow and boring. Things seemed to happen, but with no explanation of what was causing them or why. I admit, I may have missed part of the film, but i watched the majority of it and everything just seemed to happen of its own accord without any real concern for anything else. I cant recommend this film at all.</td>\n",
       "            </tr>\n",
       "            <tr>\n",
       "                        <th id=\"T_96345b40_0b9f_11ed_912f_0242ac110002level0_row2\" class=\"row_heading level0 row2\" >2</th>\n",
       "                        <td id=\"T_96345b40_0b9f_11ed_912f_0242ac110002row2_col0\" class=\"data row2 col0\" >0 (neg)</td>\n",
       "                        <td id=\"T_96345b40_0b9f_11ed_912f_0242ac110002row2_col1\" class=\"data row2 col1\" >Mann photographs the Alberta Rocky Mountains in a superb fashion, and Jimmy Stewart and Walter Brennan give enjoyable performances as they always seem to do. &lt;br /&gt;&lt;br /&gt;But come on Hollywood - a Mountie telling the people of Dawson City, Yukon to elect themselves a marshal (yes a marshal!) and to enforce the law themselves, then gunfighters battling it out on the streets for control of the town? &lt;br /&gt;&lt;br /&gt;Nothing even remotely resembling that happened on the Canadian side of the border during the Klondike gold rush. Mr. Mann and company appear to have mistaken Dawson City for Deadwood, the Canadian North for the American Wild West.&lt;br /&gt;&lt;br /&gt;Canadian viewers be prepared for a Reefer Madness type of enjoyable howl with this ludicrous plot, or, to shake your head in disgust.</td>\n",
       "            </tr>\n",
       "            <tr>\n",
       "                        <th id=\"T_96345b40_0b9f_11ed_912f_0242ac110002level0_row3\" class=\"row_heading level0 row3\" >3</th>\n",
       "                        <td id=\"T_96345b40_0b9f_11ed_912f_0242ac110002row3_col0\" class=\"data row3 col0\" >1 (pos)</td>\n",
       "                        <td id=\"T_96345b40_0b9f_11ed_912f_0242ac110002row3_col1\" class=\"data row3 col1\" >This is the kind of film for a snowy Sunday afternoon when the rest of the world can go ahead with its own business as you descend into a big arm-chair and mellow for a couple of hours. Wonderful performances from Cher and Nicolas Cage (as always) gently row the plot along. There are no rapids to cross, no dangerous waters, just a warm and witty paddle through New York life at its best. A family film in every sense and one that deserves the praise it received.</td>\n",
       "            </tr>\n",
       "            <tr>\n",
       "                        <th id=\"T_96345b40_0b9f_11ed_912f_0242ac110002level0_row4\" class=\"row_heading level0 row4\" >4</th>\n",
       "                        <td id=\"T_96345b40_0b9f_11ed_912f_0242ac110002row4_col0\" class=\"data row4 col0\" >1 (pos)</td>\n",
       "                        <td id=\"T_96345b40_0b9f_11ed_912f_0242ac110002row4_col1\" class=\"data row4 col1\" >As others have mentioned, all the women that go nude in this film are mostly absolutely gorgeous. The plot very ably shows the hypocrisy of the female libido. When men are around they want to be pursued, but when no &quot;men&quot; are around, they become the pursuers of a 14 year old boy. And the boy becomes a man really fast (we should all be so lucky at this age!). He then gets up the courage to pursue his true love.</td>\n",
       "            </tr>\n",
       "    </tbody></table>"
      ],
      "text/plain": [
       "   label                                               text\n",
       "0      0  b\"This was an absolutely terrible movie. Don't...\n",
       "1      0  b'I have been known to fall asleep during film...\n",
       "2      0  b'Mann photographs the Alberta Rocky Mountains...\n",
       "3      1  b'This is the kind of film for a snowy Sunday ...\n",
       "4      1  b'As others have mentioned, all the women that..."
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tfds.as_dataframe(raw_train_ds.unbatch().take(5), info)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare the Data\n",
    "In particular, we remove \\<br /> tags."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.keras.layers import TextVectorization\n",
    "import string\n",
    "import re\n",
    "\n",
    "def custom_standardization(input_data):\n",
    "    lowercase = tf.strings.lower(input_data)\n",
    "    stripped_html = tf.strings.regex_replace(lowercase, \"<br />\", \" \")\n",
    "    return tf.strings.regex_replace(\n",
    "        stripped_html, f\"[{re.escape(string.punctuation)}]\", \"\"\n",
    "    )\n",
    "\n",
    "max_features = 20000\n",
    "embedding_dim = 128\n",
    "sequence_length = 500\n",
    "\n",
    "vectorize_layer = TextVectorization(\n",
    "    standardize=custom_standardization,\n",
    "    max_tokens=max_features,\n",
    "    output_mode=\"int\",\n",
    "    output_sequence_length=sequence_length,\n",
    ")\n",
    "\n",
    "# Let's make a text-only dataset (no labels):\n",
    "text_ds = raw_train_ds.map(lambda x, y: x)\n",
    "# Let's call `adapt`:\n",
    "vectorize_layer.adapt(text_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def vectorize_text(text, label):\n",
    "    text = tf.expand_dims(text, -1)\n",
    "    return vectorize_layer(text), label\n",
    "\n",
    "\n",
    "# Vectorize the data.\n",
    "train_ds = raw_train_ds.map(vectorize_text)\n",
    "val_ds = raw_val_ds.map(vectorize_text)\n",
    "test_ds = raw_test_ds.map(vectorize_text)\n",
    "\n",
    "# Do async prefetching / buffering of the data for best performance on GPU.\n",
    "train_ds = train_ds.cache().prefetch(buffer_size=10)\n",
    "val_ds = val_ds.cache().prefetch(buffer_size=10)\n",
    "test_ds = test_ds.cache().prefetch(buffer_size=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Custom Model\n",
    "`bigdl.nano.tf.keras.Embedding` is a slightly modified version of tf.keras.Embedding layer, this embedding layer only applies regularizer to the output of the embedding layer, so that the gradient to embeddings is sparse. `bigdl.nano.tf.optimzers.Adam` is a variant of the Adam optimizer that handles sparse updates more efficiently. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we create two models, one using normal Embedding layer and Adam optimizer, the other using `SparseEmbedding` and `SparseAdam`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras import layers\n",
    "from bigdl.nano.tf.keras.layers import Embedding\n",
    "from bigdl.nano.tf.optimizers import SparseAdam\n",
    "\n",
    "def make_backbone():\n",
    "    inputs = tf.keras.Input(shape=(None, embedding_dim))\n",
    "    x = layers.Dropout(0.5)(inputs)\n",
    "    x = layers.Conv1D(128, 7, padding=\"valid\", activation=\"relu\", strides=3)(x)\n",
    "    x = layers.Conv1D(128, 7, padding=\"valid\", activation=\"relu\", strides=3)(x)\n",
    "    x = layers.GlobalMaxPooling1D()(x)\n",
    "    x = layers.Dense(128, activation=\"relu\")(x)\n",
    "    x = layers.Dropout(0.5)(x)\n",
    "    predictions = layers.Dense(1, activation=\"sigmoid\", name=\"predictions\")(x)\n",
    "\n",
    "    model = Model(inputs, predictions)\n",
    "    return model\n",
    "\n",
    "def make_model():\n",
    "    inputs = tf.keras.Input(shape=(None,), dtype=\"int64\")\n",
    "    x = layers.Embedding(max_features, embedding_dim)(inputs)\n",
    "    predictions = make_backbone()(x)\n",
    "    model = Model(inputs, predictions)\n",
    "    model.compile(loss=\"binary_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"])\n",
    "    return model\n",
    "\n",
    "def make_model_mod():\n",
    "    inputs = tf.keras.Input(shape=(None,), dtype=\"int64\")\n",
    "    x = Embedding(max_features, embedding_dim)(inputs)\n",
    "    predictions = make_backbone()(x)\n",
    "    model = Model(inputs, predictions)\n",
    "    model.compile(loss=\"binary_crossentropy\", optimizer=SparseAdam(), metrics=[\"accuracy\"])\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/3\n",
      "625/625 [==============================] - 23s 36ms/step - loss: 0.4908 - accuracy: 0.7282 - val_loss: 0.2945 - val_accuracy: 0.8780\n",
      "Epoch 2/3\n",
      "625/625 [==============================] - 22s 35ms/step - loss: 0.2239 - accuracy: 0.9130 - val_loss: 0.2962 - val_accuracy: 0.8804\n",
      "Epoch 3/3\n",
      "625/625 [==============================] - 22s 35ms/step - loss: 0.1143 - accuracy: 0.9609 - val_loss: 0.4011 - val_accuracy: 0.8706\n",
      "782/782 [==============================] - 5s 6ms/step - loss: 0.4511 - accuracy: 0.8512\n"
     ]
    }
   ],
   "source": [
    "from time import time\n",
    "model = make_model()\n",
    "\n",
    "# Shorten fitting time during test\n",
    "import os\n",
    "epochs = int(os.environ.get('epochs', 3))\n",
    "\n",
    "start = time()\n",
    "model.fit(train_ds, validation_data=val_ds, epochs=epochs)\n",
    "fit_time = time() - start\n",
    "\n",
    "his = model.evaluate(test_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/3\n",
      "625/625 [==============================] - 14s 22ms/step - loss: 0.5348 - accuracy: 0.6857 - val_loss: 0.3320 - val_accuracy: 0.8600\n",
      "Epoch 2/3\n",
      "625/625 [==============================] - 14s 23ms/step - loss: 0.2479 - accuracy: 0.8983 - val_loss: 0.2969 - val_accuracy: 0.8770\n",
      "Epoch 3/3\n",
      "625/625 [==============================] - 14s 22ms/step - loss: 0.1336 - accuracy: 0.9503 - val_loss: 0.3671 - val_accuracy: 0.8808\n",
      "782/782 [==============================] - 4s 4ms/step - loss: 0.3961 - accuracy: 0.8662\n"
     ]
    }
   ],
   "source": [
    "model = make_model_mod()\n",
    "\n",
    "start = time()\n",
    "model.fit(train_ds, validation_data=val_ds, epochs=epochs)\n",
    "fit_time_mod = time() - start\n",
    "\n",
    "his_mod = model.evaluate(test_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "|        Precision     |    Fit Time(s)    | Accuracy(%) |\n",
      "|        Benchmark     |       67.36       |    85.12    |\n",
      "|     Model_modified   |       42.56       |    86.62    |\n",
      "|      Improvement(%)  |       36.81       |     1.76    |\n",
      "\n"
     ]
    }
   ],
   "source": [
    "template = \"\"\"\n",
    "|        Precision     |    Fit Time(s)    | Accuracy(%) |\n",
    "|        Benchmark     |       {:5.2f}       |    {:5.2f}    |\n",
    "|     Model_modified   |       {:5.2f}       |    {:5.2f}    |\n",
    "|      Improvement(%)  |       {:5.2f}       |    {:5.2f}    |\n",
    "\"\"\"\n",
    "summary = template.format(\n",
    "    fit_time, his[1] * 100,\n",
    "    fit_time_mod, his_mod[1] * 100,\n",
    "    (1 - fit_time_mod/fit_time) * 100,  (his_mod[1]/his[1] - 1) * 100\n",
    ")\n",
    "print(summary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.10 ('nanoTensorflow')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "422a89592e6073af4a7c8f61fc6e0436178ff93c3ec3830f9ef03054dd0a1cc1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
